# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import glob
import itertools
import os
import subprocess
import sys

import jinja2

# CUDA archs requested on the command line as a comma-separated list
# (e.g. "8.0,8.9,12.0a"). Each entry is reduced to an integer SM version:
# keep "major.minor" up to the first minor digit and drop the dot, so
# "8.9" -> 89 and "12.0a" -> 120.
ARCHS = []
# Set when at least one requested arch gets FP8-activation kernels.
SUPPORT_FP8 = False
for arch_str in sys.argv[1].split(","):
    arch_str = arch_str.strip()
    if not arch_str:
        # Tolerate empty entries, e.g. the one left by a trailing comma.
        continue
    if "." not in arch_str:
        raise ValueError(
            f"Unrecognized CUDA arch {arch_str!r}; "
            "expected 'major.minor', e.g. '8.9'."
        )
    arch = int(arch_str[: arch_str.index(".") + 2].replace(".", ""))
    ARCHS.append(arch)
    # only SM89 and SM120 fully support
    # mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
    # SM90 and SM100 can use this PTX, but it's simulated
    # with FP16 MMA, so it cannot achieve any acceleration.
    if arch in (89, 120):
        SUPPORT_FP8 = True

# Banner written at the top of every generated file; the clang-format
# directive stops formatters from rewrapping the generated lines.
FILE_HEAD_COMMENT = "\n".join(
    [
        "// auto generated by generate_kernels.py",
        "// clang-format off",
        "",
    ]
)

# Prologue of each generated .cu file: banner, includes, and the opening of
# the namespace (the file body closes it).
FILE_HEAD = FILE_HEAD_COMMENT + "\n".join(
    [
        "",
        '#include "kernel.h"',
        '#include "marlin_template.h"',
        "",
        "namespace MARLIN_NAMESPACE_NAME {",
        "",
    ]
)

# Jinja2 template for one explicit instantiation of the Marlin kernel, emitted
# once per kernel config into the generated .cu files. Each "{{name}}" is a
# placeholder, listed in the positional order passed to Marlin<...>.
TEMPLATE = (
    "template __global__ void Marlin<"
    + ", ".join(
        "{{" + param + "}}"
        for param in (
            "a_type_id",
            "b_type_id",
            "c_type_id",
            "s_type_id",
            "threads",
            "thread_m_blocks",
            "thread_n_blocks",
            "thread_k_blocks",
            "m_block_size_8",
            "stages",
            "group_blocks",
            "is_zp_float",
        )
    )
    + ">( MARLIN_KERNEL_PARAMS );"
)

# Candidate (thread_k, thread_n, threads) tile shapes.
THREAD_CONFIGS = [
    (128, 128, 256),
    (64, 256, 256),
    (64, 128, 128),
    (128, 64, 128),
]

# Candidate thread_m_blocks values. generate_new_kernels turns 0.5 into
# thread_m_blocks=1 with m_block_size_8=true.
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]

# Tile defaults shared by every entry below. An entry that lists
# "thread_m_blocks" after unpacking these defaults overrides them.
_DEFAULT_TILE_CONFIG = {
    "thread_configs": THREAD_CONFIGS,
    "thread_m_blocks": THREAD_M_BLOCKS,
}

# thread_m_blocks used by the INT8/FP8-activation entries (no 0.5).
_M_BLOCKS_NO_HALF = [1, 2, 3, 4]

# One entry per supported quantization scheme. Optional keys: "a_type" and
# "c_type" (lists of type names), "s_type" (scale type name) and
# "is_zp_float"; generate_new_kernels fills in defaults when they are absent.
QUANT_CONFIGS = [
    # AWQ-INT4
    {"b_type": "kU4", **_DEFAULT_TILE_CONFIG, "group_blocks": [-1, 2, 4, 8]},
    # HQQ
    {
        "a_type": ["kFloat16"],
        "b_type": "kU4",
        **_DEFAULT_TILE_CONFIG,
        "group_blocks": [4],
        "is_zp_float": True,
    },
    # GPTQ-INT4
    {"b_type": "kU4B8", **_DEFAULT_TILE_CONFIG, "group_blocks": [-1, 0, 2, 4, 8]},
    # GPTQ-INT8
    {"b_type": "kU8B128", **_DEFAULT_TILE_CONFIG, "group_blocks": [-1, 0, 2, 4, 8]},
    # FP8
    {"b_type": "kFE4M3fn", **_DEFAULT_TILE_CONFIG, "group_blocks": [-1, 8]},
    # NVFP4
    {
        "b_type": "kFE2M1f",
        "s_type": "kFE4M3fn",
        **_DEFAULT_TILE_CONFIG,
        "group_blocks": [1],
    },
    # MXFP4
    {
        "a_type": ["kBFloat16"],
        "b_type": "kFE2M1f",
        "s_type": "kFE8M0fnu",
        **_DEFAULT_TILE_CONFIG,
        "group_blocks": [2],
    },
    # AWQ-INT4 with INT8 activation
    {
        "a_type": ["kS8"],
        "b_type": "kU4",
        **_DEFAULT_TILE_CONFIG,
        "thread_m_blocks": _M_BLOCKS_NO_HALF,
        "group_blocks": [-1, 2, 4, 8],
    },
    # GPTQ-INT4 with INT8 activation
    {
        "a_type": ["kS8"],
        "b_type": "kU4B8",
        **_DEFAULT_TILE_CONFIG,
        "thread_m_blocks": _M_BLOCKS_NO_HALF,
        "group_blocks": [-1, 2, 4, 8],
    },
    # GPTQ-INT4 with FP8 activation
    {
        "a_type": ["kFE4M3fn"],
        "b_type": "kU4B8",
        **_DEFAULT_TILE_CONFIG,
        "thread_m_blocks": _M_BLOCKS_NO_HALF,
        "group_blocks": [-1, 2, 4, 8],
    },
    # AWQ-INT4 with FP8 activation
    {
        "a_type": ["kFE4M3fn"],
        "b_type": "kU4",
        **_DEFAULT_TILE_CONFIG,
        "thread_m_blocks": _M_BLOCKS_NO_HALF,
        "group_blocks": [-1, 2, 4, 8],
    },
    # MXFP4 with FP8 activation
    {
        "a_type": ["kFE4M3fn"],
        "b_type": "kFE2M1f",
        "c_type": ["kBFloat16"],
        "s_type": "kFE8M0fnu",
        **_DEFAULT_TILE_CONFIG,
        "thread_m_blocks": _M_BLOCKS_NO_HALF,
        "group_blocks": [2],
    },
]


def remove_old_kernels():
    """Delete previously generated kernel sources and the kernel selector.

    Removes every ``*kernel_*.cu`` file and ``kernel_selector.h`` from the
    directory containing this script. Files that do not exist are ignored
    (matching ``rm -f``); other OS errors propagate.
    """
    out_dir = os.path.dirname(__file__)
    # os.path.join (rather than "dir + '/...'") makes an empty dirname resolve
    # to the current directory instead of the filesystem root; glob.escape
    # keeps glob metacharacters in the directory path from being expanded.
    stale_paths = glob.glob(os.path.join(glob.escape(out_dir), "*kernel_*.cu"))
    stale_paths.append(os.path.join(out_dir, "kernel_selector.h"))
    for path in stale_paths:
        try:
            os.remove(path)
        except FileNotFoundError:
            pass


# Right-hand side of every ``kernel = ...;`` assignment in kernel_selector.h.
# It lists the same template arguments, in the same order, as TEMPLATE.
_SELECTOR_KERNEL_TEMPLATE = (
    "Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
    "{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
    "{{thread_n_blocks}}, {{thread_k_blocks}}, "
    "{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
    "{{is_zp_float}}>;"
)


def _collect_kernel_configs():
    """Expand QUANT_CONFIGS into per-type-combination kernel configs.

    Returns:
        A dict keyed by ``(a_type, b_type, c_type)`` type-name strings, in
        first-seen order. Each value is a list of dicts holding the remaining
        Marlin template arguments (``s_type``, ``threads``, block sizes,
        ``m_block_size_8``, ``stages``, ``group_blocks``, ``is_zp_float``).
    """
    result = {}
    for quant_config in QUANT_CONFIGS:
        c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
        a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"])
        b_type = quant_config["b_type"]
        is_zp_float = quant_config.get("is_zp_float", False)
        all_group_blocks = quant_config["group_blocks"]
        all_m_blocks = quant_config["thread_m_blocks"]
        all_thread_configs = quant_config["thread_configs"]

        for a_type, c_type in itertools.product(a_types, c_types):
            # FP8-activation kernels are skipped unless an FP8-capable arch
            # was requested.
            if not SUPPORT_FP8 and a_type == "kFE4M3fn":
                continue
            # Never pair two different 16-bit types as activation and
            # output (e.g. kFloat16 with kBFloat16).
            if "16" in a_type and "16" in c_type and a_type != c_type:
                continue
            s_type = quant_config.get("s_type", c_type)
            configs = result.setdefault((a_type, b_type, c_type), [])

            combos = itertools.product(
                all_group_blocks, all_m_blocks, all_thread_configs
            )
            for group_blocks, m_blocks, (thread_k, thread_n, threads) in combos:
                if threads == 256:
                    # With 256 threads keep only (128, 128) tiles for small
                    # batches (m_blocks <= 1) and only (64, 256) tiles for
                    # large batches (m_blocks > 1).
                    if m_blocks <= 1 and (thread_k, thread_n) != (128, 128):
                        continue
                    if m_blocks > 1 and (thread_k, thread_n) != (64, 256):
                        continue

                configs.append(
                    {
                        "threads": threads,
                        "s_type": s_type,
                        # m_blocks == 0.5 is encoded as one m-block with
                        # m_block_size_8 set.
                        "thread_m_blocks": max(m_blocks, 1),
                        "thread_k_blocks": thread_k // 16,
                        "thread_n_blocks": thread_n // 16,
                        "m_block_size_8": "true" if m_blocks == 0.5 else "false",
                        "stages": "pipe_stages",
                        "group_blocks": group_blocks,
                        "is_zp_float": "true" if is_zp_float else "false",
                    }
                )
    return result


def _selector_condition(a_type, b_type, c_type, config):
    """Return the C++ condition in kernel_selector.h that picks *config*."""
    return " && ".join(
        [
            f"a_type == vllm::{a_type}",
            f"b_type == vllm::{b_type}",
            f"c_type == vllm::{c_type}",
            f"s_type == vllm::{config['s_type']}",
            f"threads == {config['threads']}",
            f"thread_m_blocks == {config['thread_m_blocks']}",
            f"thread_n_blocks == {config['thread_n_blocks']}",
            f"thread_k_blocks == {config['thread_k_blocks']}",
            f"m_block_size_8 == {config['m_block_size_8']}",
            f"group_blocks == {config['group_blocks']}",
            f"is_zp_float == {config['is_zp_float']}",
        ]
    )


def generate_new_kernels():
    """Write the Marlin kernel instantiation files and ``kernel_selector.h``.

    For each (a_type, b_type, c_type) combination derived from QUANT_CONFIGS,
    writes ``sm80_kernel_<a>_<b>_<c>.cu`` (``sm89_...`` for FP8 activations)
    containing one explicit Marlin instantiation per config. It also writes
    ``kernel_selector.h``, an ``if`` / ``else if`` chain that assigns
    ``kernel`` for every config. When SUPPORT_FP8 is False, the chain ends
    with a branch that fails via TORCH_CHECK for FP8 activations. All files
    go into the directory containing this script.
    """
    out_dir = os.path.dirname(__file__)
    # Compile each Jinja2 template once instead of once per config.
    instantiation_template = jinja2.Template(TEMPLATE)
    selector_template = jinja2.Template(_SELECTOR_KERNEL_TEMPLATE)
    selector_branches = []

    for (a_type, b_type, c_type), config_list in _collect_kernel_configs().items():
        instantiations = []
        for config in config_list:
            render_args = {
                "a_type_id": f"vllm::{a_type}.id()",
                "b_type_id": f"vllm::{b_type}.id()",
                "c_type_id": f"vllm::{c_type}.id()",
                "s_type_id": f"vllm::{config['s_type']}.id()",
                **config,
            }
            instantiations.append(instantiation_template.render(**render_args))

            keyword = "else if" if selector_branches else "if"
            condition = _selector_condition(a_type, b_type, c_type, config)
            kernel_expr = selector_template.render(**render_args)
            selector_branches.append(
                f"{keyword} ({condition})\n  kernel = {kernel_expr}\n"
            )

        file_content = FILE_HEAD + "\n\n"
        file_content += "\n\n".join(instantiations) + "\n\n}\n"
        # FP8-activation kernels go to sm89_* files, everything else to
        # sm80_*; presumably the build compiles each prefix only for matching
        # archs (TODO confirm against the CMake rules).
        arch_prefix = "sm89" if a_type == "kFE4M3fn" else "sm80"
        # [1:] drops the leading "k" of each type name.
        filename = (
            f"{arch_prefix}_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
        ).lower()
        with open(os.path.join(out_dir, filename), "w", encoding="utf-8") as f:
            f.write(file_content)

    kernel_selector_str = FILE_HEAD_COMMENT + "".join(selector_branches)
    if not SUPPORT_FP8 and selector_branches:
        # FP8-activation kernels were not generated, so route such requests
        # to an explicit TORCH_CHECK failure.
        kernel_selector_str += (
            "else if (a_type == vllm::kFE4M3fn)\n"
            "  TORCH_CHECK(false, "
            '"marlin kernel with fp8 activation is not built.");\n'
        )

    selector_path = os.path.join(out_dir, "kernel_selector.h")
    with open(selector_path, "w", encoding="utf-8") as f:
        f.write(kernel_selector_str)


if __name__ == "__main__":
    # Clear stale outputs first, then regenerate every file from scratch.
    for step in (remove_old_kernels, generate_new_kernels):
        step()
